import re
import torch
from matplotlib import pyplot as plt, patches

from qllmt.dprofile.profiler import CustomProfiler


def attach_model_profilers(model: torch.nn.Module, regex, use_cuda_sync=True, profile_backward=False):
    """
    Instruments a PyTorch model to measure forward and backward pass times of modules matching a regex.

    For every module yielded by ``model.named_modules()`` whose name matches ``regex`` (via
    ``re.match``, i.e. anchored at the start of the name), a ``CustomProfiler`` named
    ``"<module name>.forward"`` is started by a forward pre-hook and stopped by a forward hook.
    With ``profile_backward=True`` a second profiler named ``"<module name>.backward"`` is
    started by a full backward pre-hook and stopped by a full backward hook. The root module
    (whose name is the empty string) is labelled ``"<base>"`` in the profiler names.

    Args:
        model (torch.nn.Module): The PyTorch model to profile.
        regex (str): Regular expression matched against module names with ``re.match``.
        use_cuda_sync (bool): Passed through to every ``CustomProfiler``. Default is True.
        profile_backward (bool): Whether to also profile the backward pass. Default is False.

    Returns:
        dict: Maps each profiler name to a ``(start_hook_handle, stop_hook_handle, profiler)``
            tuple, where the handles are those returned by the ``register_*_hook`` calls.
            Pass this dict to ``detach_model_profilers`` to remove the hooks again.
    """
    hooks = {}

    pattern = re.compile(regex)

    def attach_hooks(name, module: torch.nn.Module):
        if not name:
            # The root module has an empty name; give it a readable label.
            name = '<base>'

        forward_profiler = CustomProfiler(f"{name}.forward", use_cuda_sync=use_cuda_sync)

        def pre_forward_hook(module, args):
            forward_profiler.start()

        def forward_hook(module, args, output):
            forward_profiler.stop()

        # prepend=True puts the start hook ahead of already-registered pre-hooks, and
        # prepend=False puts the stop hook after already-registered forward hooks, so the
        # timed span encloses those other hooks as well as the module's forward itself.
        h1 = module.register_forward_pre_hook(pre_forward_hook, prepend=True)
        h2 = module.register_forward_hook(forward_hook, prepend=False)
        hooks[forward_profiler.name] = (h1, h2, forward_profiler)

        if not profile_backward:
            return

        # Only build the backward profiler when it will actually be hooked up.
        backward_profiler = CustomProfiler(f"{name}.backward", use_cuda_sync=use_cuda_sync)

        def pre_backward_hook(module, grad_output):
            backward_profiler.start()

        def backward_hook(module, grad_input, grad_output):
            backward_profiler.stop()

        h3 = module.register_full_backward_pre_hook(pre_backward_hook, prepend=True)
        h4 = module.register_full_backward_hook(backward_hook, prepend=False)
        hooks[backward_profiler.name] = (h3, h4, backward_profiler)

    # Note: the regex is matched against the raw name ('' for the root module).
    for name, module in model.named_modules():
        if pattern.match(name):
            attach_hooks(name, module)

    return hooks


def detach_model_profilers(hooks):
    """
    Undo ``attach_model_profilers``: remove every registered hook and stop its profiler.

    Args:
        hooks (dict): Mapping whose values are tuples laid out as
            ``(start_hook_handle, stop_hook_handle, profiler, ...)``. For each value, the two
            handles have ``remove()`` called on them in order, and then the profiler's
            ``stop()`` is called.
    """
    for entry in hooks.values():
        handles = entry[0:2]
        for handle in handles:
            handle.remove()
        profiler = entry[2]
        profiler.stop()


def _normalized_layer_name(layer_name, sub=''):
    # remove numbers from layer name
    return re.sub(r'\d+', sub, layer_name)


def visualize_metrics(metrics, title="Profiler Metrics", figsize=(12, 6)):
    """
    Visualizes profiling metrics as a horizontal plot with boxes representing execution spans.

    Each recorded entry is drawn as a box that starts at its ``last_start_time`` and whose
    width is its ``mean_time``. It is labelled with that mean duration. Entries whose names
    differ only in digits (e.g. ``layers.0.attn`` and ``layers.1.attn``) share one row. That
    row is labelled with the digit-free name. Rows are assigned in order of first start time.

    All times are shifted so the earliest start is 0. They are then multiplied by 1e6 and
    labelled as microseconds, so the raw values are presumably in seconds (TODO: confirm
    against the profiler).

    The figure is shown with ``plt.show()``.

    Args:
        metrics (dict): The profiling metrics dictionary from the Profiler. Each value must
            provide the keys ``"last_start_time"``, ``"last_end_time"`` and ``"mean_time"``.
            Entries whose start or end time is ``None`` are skipped.
        title (str): Title of the plot.
        figsize (tuple): Size of the figure.

    Raises:
        ValueError: If no entry has both a recorded start and end time.
    """
    # Drop entries that were never started/stopped. This must happen before the time
    # conversion below, since ``None - float`` would raise a TypeError.
    recorded = [
        (key, data)
        for key, data in metrics.items()
        if data["last_start_time"] is not None and data["last_end_time"] is not None
    ]
    if not recorded:
        raise ValueError("visualize_metrics: no metrics have both a recorded start and end time")

    # Stable sort by start time, so rows are assigned in order of first appearance.
    recorded.sort(key=lambda item: item[1]["last_start_time"])
    base_time = recorded[0][1]["last_start_time"]

    _, ax = plt.subplots(figsize=figsize)

    # Normalized (digit-free) layer name -> y row. Dict insertion order equals row order,
    # so the keys double as the y-axis labels.
    rows = {}
    right_edge = 0.0
    for key, data in recorded:
        # Shift relative to the earliest start and scale to microseconds.
        start_time = (data["last_start_time"] - base_time) * 1e6
        end_time = (data["last_end_time"] - base_time) * 1e6
        duration = data["mean_time"] * 1e6

        # The first occurrence of a normalized name claims the next free row; later
        # occurrences reuse that same row.
        row = rows.setdefault(_normalized_layer_name(key), len(rows))

        # Box widths come from the mean time, so a box may extend past the last recorded
        # end time; widen the x-range so it is not clipped.
        right_edge = max(right_edge, end_time, start_time + duration)

        ax.add_patch(
            patches.Rectangle(
                (start_time, row - 0.4),  # Bottom-left corner
                duration,  # Width (mean duration)
                0.8,  # Height
                facecolor="skyblue",
                edgecolor="black",
                linewidth=1,
            )
        )

        # Label the box with its duration, centred inside it.
        ax.text(
            start_time + duration / 2,
            row,
            f"{duration:.0f}us",
            ha="center", va="center", fontsize=9,
        )

    # Configure axes
    ax.set_yticks(list(range(len(rows))))
    ax.set_yticklabels(list(rows))
    ax.set_xlabel("Time (us)")
    ax.set_title(title)
    ax.grid(True, axis="x", linestyle="--", alpha=0.7)
    ax.set_xlim(left=0, right=right_edge)
    ax.set_ylim(bottom=-1, top=len(rows))

    plt.tight_layout()
    plt.show()
